{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinforcement Learning for seq2seq\n",
    "\n",
    "This time we'll solve a problem of transribing hebrew words in english, also known as g2p (grapheme2phoneme)\n",
    "\n",
    " * word (sequence of letters in source language) -> translation (sequence of letters in target language)\n",
    "\n",
    "Unlike what most deep learning researchers do, we won't only train it to maximize likelihood of correct translation, but also employ reinforcement learning to actually teach it to translate with as few errors as possible.\n",
    "\n",
    "\n",
    "### About the task\n",
    "\n",
    "One notable property of Hebrew is that it's consonant language. That is, there are no wovels in the written language. One could represent wovels with diacritics above consonants, but you don't expect people to do that in everyay life.\n",
    "\n",
    "Therefore, some hebrew characters will correspond to several english letters and others - to none, so we should use encoder-decoder architecture to figure that out.\n",
    "\n",
    "![img](https://esciencegroup.files.wordpress.com/2016/03/seq2seq.jpg)\n",
    "_(img: esciencegroup.files.wordpress.com)_\n",
    "\n",
    "Encoder-decoder architectures are about converting anything to anything, including\n",
    " * Machine translation and spoken dialogue systems\n",
    " * [Image captioning](http://mscoco.org/dataset/#captions-challenge2015) and [image2latex](https://openai.com/requests-for-research/#im2latex) (convolutional encoder, recurrent decoder)\n",
    " * Generating [images by captions](https://arxiv.org/abs/1511.02793) (recurrent encoder, convolutional decoder)\n",
    " * Grapheme2phoneme - convert words to transcripts\n",
    "  \n",
    "We chose simplified __Hebrew->English__ machine translation for words and short phrases (character-level), as it is relatively quick to train even without a gpu cluster."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# If True, only translates phrases shorter than 20 characters (way easier).",
    "\n",
    "EASY_MODE = True",
    "\n",
    "# Please keep it until you're done debugging your code",
    "\n",
    "# If false, works with all phrases (please switch to this mode for homework assignment)",
    "\n",
    "\n",
    "# way we translate. Either \"he-to-en\" or \"en-to-he\"",
    "\n",
    "MODE = \"he-to-en\"",
    "\n",
    "# maximal length of _generated_ output, does not affect training",
    "\n",
    "MAX_OUTPUT_LENGTH = 50 if not EASY_MODE else 20",
    "\n",
    "REPORT_FREQ = 100                          # how often to evaluate validation score"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 1: preprocessing\n",
    "\n",
    "We shall store dataset as a dictionary\n",
    "`{ word1:[translation1,translation2,...], word2:[...],...}`.\n",
    "\n",
    "This is mostly due to the fact that many words have several correct translations.\n",
    "\n",
    "We have implemented this thing for you so that you can focus on more interesting parts.\n",
    "\n",
    "\n",
    "__Attention python2 users!__ You may want to cast everything to unicode later during homework phase, just make sure you do it _everywhere_."
   ]
  },
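  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A minimal sketch of the dictionary layout described above. The three (English, Hebrew) pairs are hand-picked illustrations and `toy_word_to_translation` is a hypothetical name; the real data is read from `main_dataset.txt` in the next cell.\n",
    "\n",
    "```python\n",
    "from collections import defaultdict\n",
    "\n",
    "# Illustrative (english, hebrew) pairs, not taken from the real dataset file.\n",
    "sample_pairs = [('shalom', 'שלום'), ('salom', 'שלום'), ('toda', 'תודה')]\n",
    "\n",
    "toy_word_to_translation = defaultdict(list)\n",
    "for en, he in sample_pairs:\n",
    "    # The Hebrew word is the key; every accepted English spelling goes into its list.\n",
    "    toy_word_to_translation[he].append(en)\n",
    "\n",
    "print(dict(toy_word_to_translation))\n",
    "```\n",
    "\n",
    "One key can hold several references, which is exactly what the scoring function relies on later."
   ]
  },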
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np",
    "\n",
    "from collections import defaultdict",
    "\n",
    "word_to_translation = defaultdict(list)  # our dictionary",
    "\n",
    "\n",
    "bos = '_'",
    "\n",
    "eos = ';'",
    "\n",
    "\n",
    "with open(\"main_dataset.txt\", encoding='utf8') as fin:",
    "\n",
    "    for line in fin:",
    "\n",
    "\n",
    "        en, he = line[:-1].lower().replace(bos, ' ').replace(eos,",
    "\n",
    "                                                             ' ').split('\\t')",
    "\n",
    "        word, trans = (he, en) if MODE == 'he-to-en' else (en, he)",
    "\n",
    "\n",
    "        if len(word) < 3:",
    "\n",
    "            continue",
    "\n",
    "        if EASY_MODE:",
    "\n",
    "            if max(len(word), len(trans)) > 20:",
    "\n",
    "                continue",
    "\n",
    "\n",
    "        word_to_translation[word].append(trans)",
    "\n",
    "\n",
    "print(\"size = \", len(word_to_translation))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get all unique lines in source language",
    "\n",
    "all_words = np.array(list(word_to_translation.keys()))",
    "\n",
    "# get all unique lines in translation language",
    "\n",
    "all_translations = np.array(",
    "\n",
    "    [ts for all_ts in word_to_translation.values() for ts in all_ts])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### split the dataset\n",
    "\n",
    "We hold out 10% of all words to be used for validation.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sklearn.model_selection import train_test_split",
    "\n",
    "train_words, test_words = train_test_split(",
    "\n",
    "    all_words, test_size=0.1, random_state=42)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Building vocabularies\n",
    "\n",
    "We now need to build vocabularies that map strings to token ids and vice versa. We're gonna need these fellas when we feed training data into model or convert output matrices into english words."
   ]
  },
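  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the idea concrete, here is a toy character-level vocabulary. It is only a sketch: the helper names `to_matrix` and `to_lines` mirror the interface used below, but the id layout (0 = bos, 1 = eos, padding with eos) is an assumption and the real `Vocab` class from `voc.py` may differ in details.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "bos, eos = '_', ';'\n",
    "toy_lines = ['abc', 'ca']\n",
    "\n",
    "# Assumed layout: id 0 = bos, id 1 = eos, then the remaining characters in sorted order.\n",
    "tokens = [bos, eos] + sorted(set(''.join(toy_lines)))\n",
    "token_to_ix = {t: i for i, t in enumerate(tokens)}\n",
    "\n",
    "def to_matrix(lines):\n",
    "    # Wrap each line in bos/eos and pad the tail with eos so all rows have equal length.\n",
    "    max_len = max(map(len, lines)) + 2\n",
    "    matrix = np.full((len(lines), max_len), token_to_ix[eos], dtype='int32')\n",
    "    for i, line in enumerate(lines):\n",
    "        ids = [token_to_ix[t] for t in bos + line + eos]\n",
    "        matrix[i, :len(ids)] = ids\n",
    "    return matrix\n",
    "\n",
    "def to_lines(matrix):\n",
    "    # Drop bos, then cut each row at the first eos.\n",
    "    lines = []\n",
    "    for row in matrix:\n",
    "        chars = [tokens[i] for i in row[1:]]\n",
    "        lines.append(''.join(chars).split(eos)[0])\n",
    "    return lines\n",
    "\n",
    "toy_ids = to_matrix(toy_lines)\n",
    "print(toy_ids)\n",
    "print(to_lines(toy_ids))\n",
    "```"
   ]
  },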
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from voc import Vocab",
    "\n",
    "inp_voc = Vocab.from_lines(''.join(all_words), bos=bos, eos=eos, sep='')",
    "\n",
    "out_voc = Vocab.from_lines(''.join(all_translations), bos=bos, eos=eos, sep='')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here's how you cast lines into ids and backwards.",
    "\n",
    "batch_lines = all_words[:5]",
    "\n",
    "batch_ids = inp_voc.to_matrix(batch_lines)",
    "\n",
    "batch_lines_restored = inp_voc.to_lines(batch_ids)",
    "\n",
    "\n",
    "print(\"lines\")",
    "\n",
    "print(batch_lines)",
    "\n",
    "print(\"\\nwords to ids (0 = bos, 1 = eos):\")",
    "\n",
    "print(batch_ids)",
    "\n",
    "print(\"\\nback to words\")",
    "\n",
    "print(batch_lines_restored)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Draw word/translation length distributions to estimate the scope of the task."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import matplotlib.pyplot as plt",
    "\n",
    "%matplotlib inline",
    "\n",
    "plt.figure(figsize=[8, 4])",
    "\n",
    "plt.subplot(1, 2, 1)",
    "\n",
    "plt.title(\"words\")",
    "\n",
    "plt.hist(list(map(len, all_words)), bins=20)",
    "\n",
    "\n",
    "plt.subplot(1, 2, 2)",
    "\n",
    "plt.title('translations')",
    "\n",
    "plt.hist(list(map(len, all_translations)), bins=20)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Step 3: deploy encoder-decoder (1 point)\n",
    "\n",
    "__assignment starts here__\n",
    "\n",
    "Our architecture consists of two main blocks:\n",
    "* Encoder reads words character by character and outputs code vector (usually a function of last RNN state)\n",
    "* Decoder takes that code vector and produces translations character by character\n",
    "\n",
    "Than it gets fed into a model that follows this simple interface:\n",
    "* __`model.symbolic_translate(inp, **flags) -> out, logp`__ - takes symbolic int32 matrix of hebrew words, produces output tokens sampled from the model and output log-probabilities for all possible tokens at each tick.\n",
    "   * if given flag __`greedy=True`__, takes most likely next token at each iteration. Otherwise samples with next token probabilities predicted by model.\n",
    "* __`model.symbolic_score(inp, out, **flags) -> logp`__ - takes symbolic int32 matrices of hebrew words and their english translations. Computes the log-probabilities of all possible english characters given english prefices and hebrew word.\n",
    "\n",
    "That's all! It's as hard as it gets. With those two methods alone you can implement all kinds of prediction and training."
   ]
  },
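  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The difference between `greedy=True` and sampling can be illustrated on a single decoding step. This is a plain numpy sketch with made-up log-probabilities, not the model's actual implementation.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "rng = np.random.RandomState(0)\n",
    "\n",
    "# Made-up log-probabilities for one decoding step: batch of 2, vocabulary of 4 tokens.\n",
    "step_logp = np.log(np.array([[0.7, 0.1, 0.1, 0.1],\n",
    "                             [0.25, 0.25, 0.4, 0.1]]))\n",
    "\n",
    "# greedy=True: always take the most likely token.\n",
    "greedy_tokens = step_logp.argmax(axis=-1)\n",
    "\n",
    "# greedy=False: draw each token from the predicted distribution.\n",
    "probs = np.exp(step_logp)\n",
    "sampled_tokens = np.array([rng.choice(len(p), p=p) for p in probs])\n",
    "\n",
    "print(greedy_tokens, sampled_tokens)\n",
    "```\n",
    "\n",
    "Greedy decoding is deterministic, which is why the tests below require it for `translate`; sampling is what policy gradient will use for exploration."
   ]
  },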
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# set flags here if necessary",
    "\n",
    "import theano",
    "\n",
    "theano.config.floatX = 'float32'",
    "\n",
    "import theano.tensor as T",
    "\n",
    "import lasagne"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from basic_model_theano import BasicTranslationModel",
    "\n",
    "model = BasicTranslationModel(inp_voc, out_voc,",
    "\n",
    "                              emb_size=64, hid_size=128)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Play around with symbolic_translate and symbolic_score",
    "\n",
    "inp = T.constant(np.random.randint(0, 10, [3, 5], dtype='int32'))",
    "\n",
    "out = T.constant(np.random.randint(0, 10, [3, 5], dtype='int32'))",
    "\n",
    "\n",
    "# translate inp (with untrained model)",
    "\n",
    "sampled_out, logp = model.symbolic_translate(inp, greedy=False)",
    "\n",
    "dummy_translate = theano.function([], sampled_out, updates=model.auto_updates)",
    "\n",
    "\n",
    "print(\"\\nSymbolic_translate output:\\n\", sampled_out, logp)",
    "\n",
    "print(\"\\nSample translations:\\n\", dummy_translate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# score logp(out | inp) with untrained input",
    "\n",
    "logp = model.symbolic_score(inp, out)",
    "\n",
    "dummy_score = theano.function([], logp)",
    "\n",
    "\n",
    "print(\"\\nSymbolic_score output:\\n\", logp)",
    "\n",
    "print(\"\\nLog-probabilities (clipped):\\n\", dummy_score()[:, :2, :5])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prepare any operations you want here",
    "\n",
    "\n",
    "inp = T.imatrix(\"input tokens [batch,time]\")",
    "\n",
    "trans, _ = <build symbolic translations with greedy = True >",
    "\n",
    "translate_fun = theano.function([inp], trans, updates=model.auto_updates)",
    "\n",
    "\n",
    "\n",
    "def translate(lines):",
    "\n",
    "    \"\"\"",
    "\n",
    "    You are given a list of input lines. ",
    "\n",
    "    Make your neural network translate them.",
    "\n",
    "    :return: a list of output lines",
    "\n",
    "    \"\"\"",
    "\n",
    "    # Convert lines to a matrix of indices",
    "\n",
    "    lines_ix = <YOUR CODE >",
    "\n",
    "\n",
    "    # Compute translations in form of indices (call your function)",
    "\n",
    "    trans_ix = <YOUR CODE >",
    "\n",
    "\n",
    "    # Convert translations back into strings",
    "\n",
    "    return out_voc.to_lines(trans_ix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"Sample inputs:\", all_words[:3])",
    "\n",
    "print(\"Dummy translations:\", translate(all_words[:3]))",
    "\n",
    "\n",
    "assert trans.ndim == 2 and trans.dtype.startswith(",
    "\n",
    "    'int'), \"trans must be a tensor of integers (token ids)\"",
    "\n",
    "assert translate(all_words[:3]) == translate(",
    "\n",
    "    all_words[:3]), \"make sure translation is deterministic (use greedy=True and disable any noise layers)\"",
    "\n",
    "assert type(translate(all_words[:3])) is list and (type(translate(all_words[:1])[0]) is str or type(",
    "\n",
    "    translate(all_words[:1])[0]) is unicode), \"translate(lines) must return a sequence of strings!\"",
    "\n",
    "print(\"Tests passed!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Scoring function\n",
    "\n",
    "LogLikelihood is a poor estimator of model performance.\n",
    "* If we predict zero probability once, it shouldn't ruin entire model.\n",
    "* It is enough to learn just one translation if there are several correct ones.\n",
    "* What matters is how many mistakes model's gonna make when it translates!\n",
    "\n",
    "Therefore, we will use minimal Levenshtein distance. It measures how many characters do we need to add/remove/replace from model translation to make it perfect. Alternatively, one could use character-level BLEU/RougeL or other similar metrics.\n",
    "\n",
    "The catch here is that Levenshtein distance is not differentiable: it isn't even continuous. We can't train our neural network to maximize it by gradient descent."
   ]
  },
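  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For reference, here is a small pure-python sketch of the metric. The notebook itself relies on the `editdistance` package; `levenshtein` and `min_distance` below are illustrative names written only to show what is being computed.\n",
    "\n",
    "```python\n",
    "def levenshtein(a, b):\n",
    "    # Classic dynamic programming over prefixes, keeping only the previous row.\n",
    "    prev = list(range(len(b) + 1))\n",
    "    for i, ca in enumerate(a, start=1):\n",
    "        curr = [i]\n",
    "        for j, cb in enumerate(b, start=1):\n",
    "            curr.append(min(prev[j] + 1,                 # delete ca\n",
    "                            curr[j - 1] + 1,             # insert cb\n",
    "                            prev[j - 1] + (ca != cb)))   # replace or match\n",
    "        prev = curr\n",
    "    return prev[-1]\n",
    "\n",
    "def min_distance(prediction, references):\n",
    "    # Distance to the closest of several acceptable translations.\n",
    "    return min(levenshtein(prediction, ref) for ref in references)\n",
    "\n",
    "print(levenshtein('kitten', 'sitting'))              # 3\n",
    "print(min_distance('shalom', ['shalom', 'salom']))   # 0\n",
    "```\n",
    "\n",
    "The result is always an integer, so a tiny change in model weights either leaves it untouched or makes it jump. That is why we will need policy gradient later."
   ]
  },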
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import editdistance  # !pip install editdistance",
    "\n",
    "\n",
    "\n",
    "def get_distance(word, trans):",
    "\n",
    "    \"\"\"",
    "\n",
    "    A function that takes word and predicted translation",
    "\n",
    "    and evaluates (Levenshtein's) edit distance to closest correct translation",
    "\n",
    "    \"\"\"",
    "\n",
    "    references = word_to_translation[word]",
    "\n",
    "    assert len(references) != 0, \"wrong/unknown word\"",
    "\n",
    "    return min(editdistance.eval(trans, ref) for ref in references)",
    "\n",
    "\n",
    "\n",
    "def score(words, bsize=100):",
    "\n",
    "    \"\"\"a function that computes levenshtein distance for bsize random samples\"\"\"",
    "\n",
    "    assert isinstance(words, np.ndarray)",
    "\n",
    "\n",
    "    batch_words = np.random.choice(words, size=bsize, replace=False)",
    "\n",
    "    batch_trans = translate(batch_words)",
    "\n",
    "\n",
    "    distances = list(map(get_distance, batch_words, batch_trans))",
    "\n",
    "\n",
    "    return np.array(distances, dtype='float32')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# should be around 5-50 and decrease rapidly after training :)",
    "\n",
    "[score(test_words, 10).mean() for _ in range(5)]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 2: Supervised pre-training\n",
    "\n",
    "Here we define a function that trains our model through maximizing log-likelihood a.k.a. minimizing crossentropy."
   ]
  },
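  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The masked crossentropy used in the next cell can be sketched in numpy as follows. All numbers are made up, and the mask convention (ones up to and including the first eos, zeros on padding) is our assumption about what `get_mask_by_eos` returns.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "eos_ix = 1\n",
    "\n",
    "# Made-up reference token ids [batch, time]; everything after the first eos is padding.\n",
    "reference = np.array([[0, 3, 2, 1, 1],\n",
    "                      [0, 2, 1, 1, 1]])\n",
    "\n",
    "# Made-up log-probabilities the model assigns to each reference token, same shape.\n",
    "logp_reference = np.log(np.full(reference.shape, 0.5))\n",
    "\n",
    "# Mask is 1 up to and including the first eos, 0 afterwards.\n",
    "is_eos = (reference == eos_ix).astype('int32')\n",
    "eos_before = np.cumsum(is_eos, axis=1) - is_eos\n",
    "mask = (eos_before == 0).astype('float32')\n",
    "\n",
    "# Average negative log-likelihood over real (non-padding) tokens only.\n",
    "crossentropy = -logp_reference\n",
    "loss = (crossentropy * mask).sum() / mask.sum()\n",
    "\n",
    "print(mask)\n",
    "print(loss)\n",
    "```\n",
    "\n",
    "Dividing by `mask.sum()` rather than by the matrix size keeps short and long words from being weighted by their amount of padding."
   ]
  },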
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from agentnet.learning.generic import get_values_for_actions, get_mask_by_eos",
    "\n",
    "\n",
    "\n",
    "class llh_trainer:",
    "\n",
    "\n",
    "    # variable for correct answers",
    "\n",
    "    input_sequence = T.imatrix(\"input sequence [batch,time]\")",
    "\n",
    "    reference_answers = T.imatrix(\"reference translations [batch, time]\")",
    "\n",
    "\n",
    "    # Compute log-probabilities of all possible tokens at each step. Use model interface.",
    "\n",
    "    logprobs_seq = <YOUR CODE >",
    "\n",
    "\n",
    "    # compute mean crossentropy",
    "\n",
    "    crossentropy = - get_values_for_actions(logprobs_seq, reference_answers)",
    "\n",
    "\n",
    "    mask = get_mask_by_eos(T.eq(reference_answers, out_voc.eos_ix))",
    "\n",
    "\n",
    "    loss = T.sum(crossentropy * mask)/T.sum(mask)",
    "\n",
    "\n",
    "    # Build weight updates. Use model.weights to get all trainable params.",
    "\n",
    "    updates = <YOUR CODE >",
    "\n",
    "\n",
    "    train_step = theano.function(",
    "\n",
    "        [input_sequence, reference_answers], loss, updates=updates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Actually run training on minibatches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import random",
    "\n",
    "\n",
    "\n",
    "def sample_batch(words, word_to_translation, batch_size):",
    "\n",
    "    \"\"\"",
    "\n",
    "    sample random batch of words and random correct translation for each word",
    "\n",
    "    example usage:",
    "\n",
    "    batch_x,batch_y = sample_batch(train_words, word_to_translations,10)",
    "\n",
    "    \"\"\"",
    "\n",
    "    # choose words",
    "\n",
    "    batch_words = np.random.choice(words, size=batch_size)",
    "\n",
    "\n",
    "    # choose translations",
    "\n",
    "    batch_trans_candidates = list(map(word_to_translation.get, batch_words))",
    "\n",
    "    batch_trans = list(map(random.choice, batch_trans_candidates))",
    "\n",
    "\n",
    "    return inp_voc.to_matrix(batch_words), out_voc.to_matrix(batch_trans)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bx, by = sample_batch(train_words, word_to_translation, batch_size=3)",
    "\n",
    "print(\"Source:\")",
    "\n",
    "print(bx)",
    "\n",
    "print(\"Target:\")",
    "\n",
    "print(by)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import clear_output",
    "\n",
    "from tqdm import tqdm, trange  # or use tqdm_notebook,tnrange",
    "\n",
    "\n",
    "loss_history = []",
    "\n",
    "editdist_history = []",
    "\n",
    "\n",
    "for i in trange(25000):",
    "\n",
    "    loss = llh_trainer.train_step(",
    "\n",
    "        *sample_batch(train_words, word_to_translation, 32))",
    "\n",
    "    loss_history.append(loss)",
    "\n",
    "\n",
    "    if (i+1) % REPORT_FREQ == 0:",
    "\n",
    "        clear_output(True)",
    "\n",
    "        current_scores = score(test_words)",
    "\n",
    "        editdist_history.append(current_scores.mean())",
    "\n",
    "        plt.figure(figsize=(12, 4))",
    "\n",
    "        plt.subplot(131)",
    "\n",
    "        plt.title('train loss / traning time')",
    "\n",
    "        plt.plot(loss_history)",
    "\n",
    "        plt.grid()",
    "\n",
    "        plt.subplot(132)",
    "\n",
    "        plt.title('val score distribution')",
    "\n",
    "        plt.hist(current_scores, bins=20)",
    "\n",
    "        plt.subplot(133)",
    "\n",
    "        plt.title('val score / traning time')",
    "\n",
    "        plt.plot(editdist_history)",
    "\n",
    "        plt.grid()",
    "\n",
    "        plt.show()",
    "\n",
    "        print(\"llh=%.3f, mean score=%.3f\" %",
    "\n",
    "              (np.mean(loss_history[-10:]), np.mean(editdist_history[-10:])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for word in train_words[:10]:",
    "\n",
    "    print(\"%s -> %s\" % (word, translate([word])[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_scores = []",
    "\n",
    "for start_i in trange(0, len(test_words), 32):",
    "\n",
    "    batch_words = test_words[start_i:start_i+32]",
    "\n",
    "    batch_trans = translate(batch_words)",
    "\n",
    "    distances = list(map(get_distance, batch_words, batch_trans))",
    "\n",
    "    test_scores.extend(distances)",
    "\n",
    "\n",
    "print(\"Supervised test score:\", np.mean(test_scores))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Preparing for reinforcement learning (2 points)\n",
    "\n",
    "First we need to define loss function as a custom theano operation.\n",
    "\n",
    "The simple way to do so is\n",
    "```\n",
    "@theano.compile.as_op(input_types,output_type(s),infer_shape)\n",
    "def my_super_function(inputs):\n",
    "    return outputs\n",
    "```\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "__Your task__ is to implement `_compute_levenshtein` function that takes matrices of words and translations, along with input masks, then converts those to actual words and phonemes and computes min-levenshtein via __get_distance__ function above.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@theano.compile.as_op([T.imatrix]*2, [T.fvector], lambda _, shapes: [shapes[0][:1]])",
    "\n",
    "def _compute_levenshtein(words_ix, trans_ix):",
    "\n",
    "    \"\"\"",
    "\n",
    "    A custom theano operation that computes levenshtein loss for predicted trans.",
    "\n",
    "\n",
    "    Params:",
    "\n",
    "    - words_ix - a matrix of letter indices, shape=[batch_size,word_length]",
    "\n",
    "    - words_mask - a matrix of zeros/ones, ",
    "\n",
    "       1 means \"word is still not finished\"",
    "\n",
    "       0 means \"word has already finished and this is padding\"",
    "\n",
    "\n",
    "    - trans_mask - a matrix of output letter indices, shape=[batch_size,translation_length]",
    "\n",
    "    - trans_mask - a matrix of zeros/ones, similar to words_mask but for trans_ix",
    "\n",
    "\n",
    "\n",
    "    Please implement the function and make sure it passes tests from the next cell.",
    "\n",
    "\n",
    "    \"\"\"",
    "\n",
    "\n",
    "    # convert words to strings",
    "\n",
    "    words = <restore words(a list of strings) from words_ix >",
    "\n",
    "\n",
    "    assert type(words) is list and type(",
    "\n",
    "        words[0]) is str and len(words) == len(words_ix)",
    "\n",
    "\n",
    "    # convert translations to lists",
    "\n",
    "    translations = <restore trans(a list of lists of phonemes) from trans_ix",
    "\n",
    "\n",
    "    assert type(translations) is list and type(",
    "\n",
    "        translations[0]) is str and len(translations) == len(trans_ix)",
    "\n",
    "\n",
    "    # computes levenstein distances. can be arbitrary python code.",
    "\n",
    "    distances = <apply get_distance to each pair of[words, translations] >",
    "\n",
    "\n",
    "    assert type(distances) in (list, tuple, np.ndarray) and len(",
    "\n",
    "        distances) == len(words_ix)",
    "\n",
    "\n",
    "    distances = np.array(list(distances), dtype='float32')",
    "\n",
    "    return distances",
    "\n",
    "\n",
    "\n",
    "# forbid gradient",
    "\n",
    "from theano.gradient import disconnected_grad",
    "\n",
    "\n",
    "\n",
    "def compute_levenshtein(*args):",
    "\n",
    "    return disconnected_grad(_compute_levenshtein(*[arg.astype('int32') for arg in args]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simple test suite to make sure your implementation is correct. Hint: if you run into any bugs, feel free to use print from inside _compute_levenshtein."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# test suite",
    "\n",
    "# sample random batch of (words, correct trans, wrong trans)",
    "\n",
    "batch_words = np.random.choice(train_words, size=100)",
    "\n",
    "batch_trans = list(map(random.choice, map(",
    "\n",
    "    word_to_translation.get, batch_words)))",
    "\n",
    "batch_trans_wrong = np.random.choice(all_translations, size=100)",
    "\n",
    "\n",
    "batch_words_ix = T.constant(inp_voc.to_matrix(batch_words))",
    "\n",
    "batch_trans_ix = T.constant(out_voc.to_matrix(batch_trans))",
    "\n",
    "batch_trans_wrong_ix = T.constant(out_voc.to_matrix(batch_trans_wrong))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# assert compute_levenshtein is zero for ideal translations",
    "\n",
    "correct_answers_score = compute_levenshtein(",
    "\n",
    "    batch_words_ix, batch_trans_ix).eval()",
    "\n",
    "\n",
    "assert np.all(correct_answers_score ==",
    "\n",
    "              0), \"a perfect translation got nonzero levenshtein score!\"",
    "\n",
    "\n",
    "print(\"Everything seems alright!\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# assert compute_levenshtein matches actual scoring function",
    "\n",
    "wrong_answers_score = compute_levenshtein(",
    "\n",
    "    batch_words_ix, batch_trans_wrong_ix).eval()",
    "\n",
    "\n",
    "true_wrong_answers_score = np.array(",
    "\n",
    "    list(map(get_distance, batch_words, batch_trans_wrong)))",
    "\n",
    "\n",
    "assert np.all(wrong_answers_score ==",
    "\n",
    "              true_wrong_answers_score), \"for some word symbolic levenshtein is different from actual levenshtein distance\"",
    "\n",
    "\n",
    "print(\"Everything seems alright!\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Once you got it working...\n",
    "\n",
    "\n",
    "* You may now want to __remove/comment asserts__ from function code for a slight speed-up.\n",
    "\n",
    "* There's a more detailed tutorial on custom theano ops here: [docs](http://deeplearning.net/software/theano/extending/extending_theano.html), [example](https://gist.github.com/justheuristic/9f4ffef6162a8089c3260fc3bbacbf46)."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Self-critical policy gradient (2 points)\n",
    "\n",
    "In this section you'll implement algorithm called self-critical sequence training (here's an [article](https://arxiv.org/abs/1612.00563)).\n",
    "\n",
    "The algorithm is a vanilla policy gradient with a special baseline. \n",
    "\n",
    "$$ \\nabla J = E_{x \\sim p(s)} E_{y \\sim \\pi(y|x)} \\nabla log \\pi(y|x) \\cdot (R(x,y) - b(x)) $$\n",
    "\n",
    "Here reward R(x,y) is a __negative levenshtein distance__ (since we minimize it). The baseline __b(x)__ represents how well model fares on word __x__.\n",
    "\n",
    "In practice, this means that we compute baseline as a score of greedy translation, $b(x) = R(x,y_{greedy}(x)) $.\n",
    "![img](https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/scheme.png)\n",
    "\n",
    "Luckily, we already obtained the required outputs: `model.greedy_translations, model.greedy_mask` and we only need to compute levenshtein using `compute_levenshtein` function.\n"
   ]
  },
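  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A toy numpy illustration of the quantities above. The distances and log-probabilities are made up; in the actual trainer they come from `compute_levenshtein` and from the model.\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Made-up edit distances for a batch of 3 words.\n",
    "sampled_distance = np.array([2., 0., 5.])   # distance of sampled translations\n",
    "greedy_distance = np.array([1., 1., 5.])    # distance of greedy translations\n",
    "\n",
    "rewards = -sampled_distance          # R(x, y), y sampled from the model\n",
    "baseline = -greedy_distance          # b(x) = R(x, y_greedy(x))\n",
    "advantage = rewards - baseline\n",
    "\n",
    "# Made-up log-probabilities log pi(y|x) of the sampled translations.\n",
    "sample_logp = np.log(np.array([0.1, 0.2, 0.05]))\n",
    "\n",
    "# Surrogate objective: its gradient w.r.t. model weights is the policy gradient,\n",
    "# provided the advantage is treated as a constant.\n",
    "surrogate = np.mean(sample_logp * advantage)\n",
    "\n",
    "print(advantage)\n",
    "```\n",
    "\n",
    "A positive advantage means the sample beat the greedy translation, so its probability gets pushed up; a negative one pushes it down; zero leaves it alone."
   ]
  },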
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "class trainer:",
    "\n",
    "\n",
    "    input_sequence = T.imatrix(\"input tokens [batch,time]\")",
    "\n",
    "\n",
    "    # use model to __sample__ symbolic translations given input_sequence",
    "\n",
    "    sample_translations, sample_logp = <your code here >",
    "\n",
    "    auto_updates = model.auto_updates",
    "\n",
    "    # use model to __greedy__ symbolic translations given input_sequence",
    "\n",
    "    greedy_translations, greedy_logp = <your code here >",
    "\n",
    "    greedy_auto_updates = model.auto_updates",
    "\n",
    "\n",
    "    # Note: you can use model.symbolic_translate(...,unroll_scan=True,max_len=MAX_OUTPUT_LENGTH)",
    "\n",
    "    # to run much faster at a cost of longer compilation",
    "\n",
    "\n",
    "    rewards = - compute_levenshtein(input_sequence, sample_translations)",
    "\n",
    "\n",
    "    baseline = <compute __negative__ levenshtein for greedy mode >",
    "\n",
    "\n",
    "    # compute advantage using rewards and baseline",
    "\n",
    "    advantage = <your code - compute advantage >",
    "\n",
    "\n",
    "    # compute log_pi(a_t|s_t), shape = [batch, seq_length]",
    "\n",
    "    logprobs_phoneme = get_values_for_actions(sample_logp, sample_translations)",
    "\n",
    "\n",
    "    # policy gradient",
    "\n",
    "    J = logprobs_phoneme*advantage[:, None]",
    "\n",
    "\n",
    "    mask = get_mask_by_eos(T.eq(sample_translations, out_voc.eos_ix))",
    "\n",
    "    loss = - T.sum(J*mask) / T.sum(mask)",
    "\n",
    "\n",
    "    # regularize with negative entropy. Don't forget the sign!",
    "\n",
    "    # note: for entropy you need probabilities for all tokens (sample_logp), not just phoneme_logprobs",
    "\n",
    "    entropy = <compute entropy matrix of shape[batch, seq_length], H = -sum(p*log_p), don't forget the sign!>",
    "\n",
    "\n",
    "    assert entropy.ndim == 2, \"please make sure elementwise entropy is of shape [batch,time]\"",
    "\n",
    "\n",
    "    loss -= 0.01*T.sum(entropy*mask) / T.sum(mask)",
    "\n",
    "\n",
    "    # compute weight updates, clip by norm",
    "\n",
    "    grads = T.grad(loss, model.weights)",
    "\n",
    "    grads = lasagne.updates.total_norm_constraint(grads, 50)",
    "\n",
    "\n",
    "    updates = lasagne.updates.adam(grads, model.weights, learning_rate=1e-5)",
    "\n",
    "\n",
    "    train_step = theano.function([input_sequence], loss,",
    "\n",
    "                                 updates=auto_updates+greedy_auto_updates+updates)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Policy gradient training\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in trange(100000):",
    "\n",
    "    loss_history.append(",
    "\n",
    "        trainer.train_step(sample_batch(",
    "\n",
    "            train_words, word_to_translation, 32)[0])",
    "\n",
    "    )",
    "\n",
    "\n",
    "    if (i+1) % REPORT_FREQ == 0:",
    "\n",
    "        clear_output(True)",
    "\n",
    "        current_scores = score(test_words)",
    "\n",
    "        editdist_history.append(current_scores.mean())",
    "\n",
    "        plt.figure(figsize=(8, 4))",
    "\n",
    "        plt.subplot(121)",
    "\n",
    "        plt.title('val score distribution')",
    "\n",
    "        plt.hist(current_scores, bins=20)",
    "\n",
    "        plt.subplot(122)",
    "\n",
    "        plt.title('val score / traning time')",
    "\n",
    "        plt.plot(editdist_history)",
    "\n",
    "        plt.grid()",
    "\n",
    "        plt.show()",
    "\n",
    "        print(\"J=%.3f, mean score=%.3f\" %",
    "\n",
    "              (np.mean(loss_history[-10:]), np.mean(editdist_history[-10:])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.translate(\"EXAMPLE;\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for word in train_words[:10]:",
    "\n",
    "    print(\"%s -> %s\" % (word, translate([word])[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "test_scores = []",
    "\n",
    "for start_i in trange(0, len(test_words), 32):",
    "\n",
    "    batch_words = test_words[start_i:start_i+32]",
    "\n",
    "    batch_trans = translate(batch_words)",
    "\n",
    "    distances = list(map(get_distance, batch_words, batch_trans))",
    "\n",
    "    test_scores.extend(distances)",
    "\n",
    "print(\"Supervised test score:\", np.mean(test_scores))",
    "\n",
    "\n",
    "# ^^ If you get Out Of Memory, please replace this with batched computation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Step 6: Make it actually work (5++ pts)\n",
    "<img src=https://github.com/yandexdataschool/Practical_RL/raw/master/yet_another_week/_resource/do_something_scst.png width=400>\n",
    "\n",
    "\n",
    "In this section we want you to finally __restart with EASY_MODE=False__ and experiment to find a good model/curriculum for that task.\n",
    "\n",
    "We recommend you to start with the following architecture\n",
    "\n",
    "```\n",
    "encoder---decoder\n",
    "\n",
    "           P(y|h)\n",
    "             ^\n",
    " LSTM  ->   LSTM\n",
    "  ^          ^\n",
    " biLSTM  ->   LSTM\n",
    "  ^          ^\n",
    "input       y_prev\n",
    "```\n",
    "\n",
    "__Note:__ you can fit all 4 state tensors of both LSTMs into a in a single state - just assume that it contains, for example, [h0, c0, h1, c1] - pack it in encode and update in decode.\n",
    "\n",
    "\n",
    "Here are some cool ideas on what you can do then.\n",
    "\n",
    "__General tips & tricks:__\n",
    "* In some tensorflow versions and for some layers, it is required that each rnn/gru/lstm cell gets it's own `tf.variable_scope(unique_name, reuse=False)`.\n",
    "  * Otherwise it will complain about wrong tensor sizes because it tries to reuse weights from one rnn to the other.\n",
    "* You will likely need to adjust pre-training time for such a network.\n",
    "* Supervised pre-training may benefit from clipping gradients somehow.\n",
    "* SCST may indulge a higher learning rate in some cases and changing entropy regularizer over time.\n",
    "* It's often useful to save pre-trained model parameters to not re-train it every time you want new policy gradient parameters. \n",
    "* When leaving training for nighttime, try setting REPORT_FREQ to a larger value (e.g. 500) not to waste time on it.\n",
    "\n",
    "__Formal criteria:__\n",
    "To get 5 points we want you to build an architecture that:\n",
    "* _doesn't consist of single GRU_\n",
    "* _works better_ than single GRU baseline. \n",
    "* We also want you to provide either learning curve or trained model, preferably both\n",
    "* ... and write a brief report or experiment log describing what you did and how it fared.\n",
    "\n",
    "### Attention\n",
    "There's more than one way to connect decoder to encoder\n",
    "  * __Vanilla:__ layer_i of encoder last state goes to layer_i of decoder initial state\n",
    "  * __Every tick:__ feed encoder last state _on every iteration_ of decoder.\n",
    "  * __Attention:__ allow decoder to \"peek\" at one (or several) positions of encoded sequence on every tick.\n",
    "  \n",
    "The most effective (and cool) of those is, of course, attention.\n",
    "You can read more about attention [in this nice blog post](https://distill.pub/2016/augmented-rnns/). The easiest way to begin is to use \"soft\" attention with \"additive\" or \"dot-product\" intermediate layers.\n",
    "\n",
    "__Tips__\n",
    "* Model usually generalizes better if you no longer allow decoder to see final encoder state\n",
    "* Once your model made it through several epochs, it is a good idea to visualize attention maps to understand what your model has actually learned\n",
    "\n",
    "* There's more stuff [here](https://github.com/yandexdataschool/Practical_RL/blob/master/week8_scst/bonus.ipynb)\n",
    "* If you opted for hard attention, we recommend [gumbel-softmax](https://blog.evjang.com/2016/11/tutorial-categorical-variational.html) instead of sampling. Also please make sure soft attention works fine before you switch to hard.\n",
    "\n",
    "### UREX\n",
    "* This is a way to improve exploration in policy-based settings. The main idea is that you find and upweight under-appreciated actions.\n",
    "* Here's [video](https://www.youtube.com/watch?v=fZNyHoXgV7M&feature=youtu.be&t=3444)\n",
    " and an [article](https://arxiv.org/abs/1611.09321).\n",
    "* You may want to reduce batch size 'cuz UREX requires you to sample multiple times per source sentence.\n",
    "* Once you got it working, try using experience replay with importance sampling instead of (in addition to) basic UREX.\n",
    "\n",
    "### Some additional ideas:\n",
    "* (advanced deep learning) It may be a good idea to first train on small phrases and then adapt to larger ones (a.k.a. training curriculum).\n",
    "* (advanced nlp) You may want to switch from raw utf8 to something like unicode or even syllables to make task easier.\n",
    "* (advanced nlp) Since hebrew words are written __with vowels omitted__, you may want to use a small Hebrew vowel markup dataset at `he-pron-wiktionary.txt`.\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert not EASY_MODE, \"make sure you set EASY_MODE = False at the top of the notebook.\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`[your report/log here or anywhere you please]`"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__Contributions:__ This notebook is brought to you by\n",
    "* Yandex [MT team](https://tech.yandex.com/translate/)\n",
    "* Denis Mazur ([DeniskaMazur](https://github.com/DeniskaMazur)), Oleg Vasilev ([Omrigan](https://github.com/Omrigan/)), Dmitry Emelyanenko ([TixFeniks](https://github.com/tixfeniks)) and Fedor Ratnikov ([justheuristic](https://github.com/justheuristic/))\n",
    "* Dataset is parsed from [Wiktionary](https://en.wiktionary.org), which is under CC-BY-SA and GFDL licenses.\n"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python",
   "pygments_lexer": "ipython3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
